from datetime import datetime
from azure.cosmos.aio import CosmosClient
from azure.cosmos import exceptions

from .database_client_base import DatabaseClientBase


class CosmosConversationClient(DatabaseClientBase):
    """Conversation-history store backed by an Azure Cosmos DB container.

    Conversations and messages are stored in the same container and are told
    apart by their ``type`` field (``"conversation"`` or ``"message"``).
    Point reads and deletes pass ``user_id`` as the partition key value.
    """

    # Only these literals may be spliced into an ORDER BY clause.
    _SORT_ORDERS = ("ASC", "DESC")

    def __init__(
        self,
        cosmosdb_endpoint: str,
        credential: any,
        database_name: str,
        container_name: str,
        enable_message_feedback: bool = False,
    ):
        """Build the Cosmos, database and container client handles.

        Args:
            cosmosdb_endpoint: URL of the Cosmos DB account.
            credential: Credential object or key handed to ``CosmosClient``.
            database_name: Name of the database holding the container.
            container_name: Name of the container holding the items.
            enable_message_feedback: When true, new messages are created with
                an empty ``feedback`` field.

        Raises:
            ValueError: If building any of the three client handles raises
                one of the Cosmos exceptions handled below.
        """
        self.cosmosdb_endpoint = cosmosdb_endpoint
        self.credential = credential
        self.database_name = database_name
        self.container_name = container_name
        self.enable_message_feedback = enable_message_feedback
        try:
            self.cosmosdb_client = CosmosClient(
                self.cosmosdb_endpoint, credential=self.credential
            )
        except exceptions.CosmosHttpResponseError as e:
            if e.status_code == 401:
                raise ValueError("Invalid credentials") from e
            raise ValueError("Invalid CosmosDB endpoint") from e

        try:
            self.database_client = self.cosmosdb_client.get_database_client(
                database_name
            )
        except exceptions.CosmosResourceNotFoundError as e:
            raise ValueError("Invalid CosmosDB database name") from e

        try:
            self.container_client = self.database_client.get_container_client(
                container_name
            )
        except exceptions.CosmosResourceNotFoundError as e:
            raise ValueError("Invalid CosmosDB container name") from e

    async def connect(self):
        """No-op: the client handles are already built in ``__init__``."""
        pass

    async def close(self):
        """Close the underlying Cosmos client.

        The credential is left untouched because it was supplied by the
        caller.
        """
        client = getattr(self, "cosmosdb_client", None)
        if client is not None:
            await client.close()

    async def ensure(self):
        """Check that the configured database and container can be read.

        Returns:
            A ``(ok, message)`` tuple; ``ok`` is ``False`` with an explanatory
            message when a handle is missing or a read fails.
        """
        if (
            not self.cosmosdb_client
            or not self.database_client
            or not self.container_client
        ):
            return False, "CosmosDB client not initialized correctly"
        try:
            await self.database_client.read()
        except Exception:
            # Any failure is reported to the caller as "not found" rather
            # than raised, so this method can serve as a health check.
            return (
                False,
                f"CosmosDB database {self.database_name} on account {self.cosmosdb_endpoint} not found",
            )

        try:
            await self.container_client.read()
        except Exception:
            return False, f"CosmosDB container {self.container_name} not found"

        return True, "CosmosDB client initialized successfully"

    async def _query_items(self, query, parameters):
        """Run a parameterized query and collect every result into a list."""
        return [
            item
            async for item in self.container_client.query_items(
                query=query, parameters=parameters
            )
        ]

    async def create_conversation(self, user_id, conversation_id, title=""):
        """Upsert a new conversation item.

        Returns:
            The upsert response, or ``False`` if the response is falsy.
        """
        # One clock read so createdAt and updatedAt start out identical.
        now = datetime.utcnow().isoformat()
        conversation = {
            "id": conversation_id,
            "type": "conversation",
            "createdAt": now,
            "updatedAt": now,
            "userId": user_id,
            "title": title,
            "conversationId": conversation_id,
        }
        # TODO: add some error handling based on the output of the upsert_item call
        resp = await self.container_client.upsert_item(conversation)
        return resp if resp else False

    async def upsert_conversation(self, conversation):
        """Upsert ``conversation`` as-is.

        Returns:
            The upsert response, or ``False`` if the response is falsy.
        """
        resp = await self.container_client.upsert_item(conversation)
        return resp if resp else False

    async def delete_conversation(self, user_id, conversation_id):
        """Delete a conversation item.

        Returns:
            The ``delete_item`` response, or ``True`` when the conversation
            does not exist (there is nothing left to delete).
        """
        try:
            return await self.container_client.delete_item(
                item=conversation_id, partition_key=user_id
            )
        except exceptions.CosmosResourceNotFoundError:
            # A missing item raises rather than returning a falsy value, so
            # the "already gone" case has to be handled here.
            return True

    async def delete_messages(self, conversation_id, user_id):
        """Delete every message of a conversation.

        Note the parameter order: ``conversation_id`` comes first here.

        Returns:
            A list with one ``delete_item`` response per deleted message;
            empty when the conversation has no messages.
        """
        messages = await self.get_messages(user_id, conversation_id)
        response_list = []
        for message in messages:
            resp = await self.container_client.delete_item(
                item=message["id"], partition_key=user_id
            )
            response_list.append(resp)
        return response_list

    async def get_conversations(self, user_id, limit, sort_order="DESC", offset=0):
        """List a user's conversations ordered by ``updatedAt``.

        Args:
            user_id: Owner of the conversations.
            limit: Maximum number of items; ``None`` disables paging.
            sort_order: ``"ASC"`` or ``"DESC"`` (case-insensitive).
            offset: Number of items to skip; only used when ``limit`` is set.

        Raises:
            ValueError: If ``sort_order`` is not ``ASC``/``DESC`` or if
                ``limit``/``offset`` cannot be converted to ``int``.
        """
        # sort_order, offset and limit end up in the query text itself, so
        # they are validated/coerced instead of being interpolated verbatim.
        order = str(sort_order).strip().upper()
        if order not in self._SORT_ORDERS:
            raise ValueError(f"Invalid sort order: {sort_order!r}")

        parameters = [{"name": "@userId", "value": user_id}]
        query = f"SELECT * FROM c where c.userId = @userId and c.type='conversation' order by c.updatedAt {order}"
        if limit is not None:
            query += f" offset {int(offset or 0)} limit {int(limit)}"

        return await self._query_items(query, parameters)

    async def get_conversation(self, user_id, conversation_id):
        """Return the user's conversation with this id, or ``None``."""
        parameters = [
            {"name": "@conversationId", "value": conversation_id},
            {"name": "@userId", "value": user_id},
        ]
        query = "SELECT * FROM c where c.id = @conversationId and c.type='conversation' and c.userId = @userId"
        conversations = await self._query_items(query, parameters)
        return conversations[0] if conversations else None

    async def create_message(self, uuid, conversation_id, user_id, input_message: dict):
        """Upsert a message and bump the parent conversation's ``updatedAt``.

        Args:
            uuid: Id of the new message item.
            conversation_id: Id of the parent conversation.
            user_id: Owner of the conversation.
            input_message: Mapping providing ``"role"`` and ``"content"``.

        Returns:
            The upsert response for the message; ``False`` if that response
            is falsy; or the string ``"Conversation not found"`` when the
            parent conversation does not exist (the message has already been
            written at that point).
        """
        now = datetime.utcnow().isoformat()
        message = {
            "id": uuid,
            "type": "message",
            "userId": user_id,
            "createdAt": now,
            "updatedAt": now,
            "conversationId": conversation_id,
            "role": input_message["role"],
            "content": input_message["content"],
        }

        if self.enable_message_feedback:
            message["feedback"] = ""

        resp = await self.container_client.upsert_item(message)
        if not resp:
            return False

        # Update the parent conversation's updatedAt field with the current
        # message's createdAt value.
        conversation = await self.get_conversation(user_id, conversation_id)
        if not conversation:
            return "Conversation not found"
        conversation["updatedAt"] = message["createdAt"]
        await self.upsert_conversation(conversation)
        return resp

    async def update_message_feedback(self, user_id, message_id, feedback):
        """Set the ``feedback`` field of an existing message.

        Returns:
            The upsert response, or ``False`` when the message does not
            exist.
        """
        try:
            message = await self.container_client.read_item(
                item=message_id, partition_key=user_id
            )
        except exceptions.CosmosResourceNotFoundError:
            # A missing item raises rather than returning a falsy value.
            return False
        if not message:
            return False
        message["feedback"] = feedback
        return await self.container_client.upsert_item(message)

    async def get_messages(self, user_id, conversation_id):
        """Return a conversation's messages, oldest first."""
        parameters = [
            {"name": "@conversationId", "value": conversation_id},
            {"name": "@userId", "value": user_id},
        ]
        # Order by createdAt: that is the timestamp create_message stamps on
        # each message; it never writes a `timestamp` field.
        query = "SELECT * FROM c WHERE c.conversationId = @conversationId AND c.type='message' AND c.userId = @userId ORDER BY c.createdAt ASC"
        return await self._query_items(query, parameters)
